import argparse
import time
from scipy.sparse.linalg import eigsh

def parse_args():
    """Parse the command-line options of this script.

    Options (all optional):
        --k           int, default 3
        --time_limit  int, seconds, default 80000
        --filepath    str, default 'Matrix_CovColon_txt'
        --d_max       int, default 40

    Returns:
        argparse.Namespace with attributes ``k``, ``time_limit``,
        ``filepath`` and ``d_max``.
    """
    parser = argparse.ArgumentParser(description="Run Julia branch and bound with customizable parameters.")
    # The defaults are rendered through %(default)s so the help text can never
    # disagree with the actual default (it used to claim 800 for --time_limit).
    parser.add_argument('--k', type=int, default=3,
                        help='Specify the parameter k (default: %(default)s)')
    parser.add_argument('--time_limit', type=int, default=80000,
                        help='Specify the time limit (in seconds, default: %(default)s)')
    parser.add_argument('--filepath', type=str, default='Matrix_CovColon_txt',
                        help='File path of the dataset')
    parser.add_argument('--d_max', type=int, default=40,
                        help='maximum size of the block allowed')
    return parser.parse_args()


import numpy as np
from scipy.linalg import sqrtm


import random

def sum_of_squares(elements):
    """Return the sum of ``x**2`` over every item of the iterable *elements*.

    An empty iterable yields the integer 0.
    """
    total = 0
    for value in elements:
        total = total + value ** 2
    return total

def calculate_median_of_sums(matrix, alpha):
    """Return the median, over random groups of entries, of the mean squared entry.

    The entries of *matrix* are shuffled with the ``random`` module and dealt
    round-robin into ``int(d**alpha)`` groups, where ``d = matrix.shape[0]``.
    For each group the mean of the squared entries is computed, and the median
    of those means is returned.
    """
    dim = matrix.shape[0]
    n_groups = int(dim ** alpha)

    # Random assignment of entries to groups comes from shuffling the flat list.
    shuffled = list(matrix.flatten())
    random.shuffle(shuffled)

    group_means = []
    for offset in range(n_groups):
        # Every n_groups-th entry starting at `offset` forms one group.
        group = shuffled[offset::n_groups]
        group_means.append(sum_of_squares(group) / len(group))

    return np.median(group_means)

def threshold_matrix(matrix, threshold):
    """Return a new array equal to *matrix* with small entries replaced by 0.

    An entry is zeroed when its absolute value is strictly below *threshold*;
    entries whose magnitude equals the threshold are kept.
    """
    below_threshold = np.abs(matrix) < threshold
    return np.where(below_threshold, 0, matrix)

def find_supp(x, ind):
    """Map the non-zero positions of *x* to the labels stored in *ind*.

    For every position along the first axis of *x* that holds a non-zero
    value, the element of *ind* at that position is collected.

    Returns:
        A list of the selected elements of *ind*, in increasing position order.
    """
    support = []
    # np.nonzero returns one index array per axis; only the first axis is used.
    for position in np.nonzero(x)[0]:
        support.append(ind[position])
    return support

def truncation(x, k):
    """Keep the k largest-magnitude entries of *x* and rescale to unit 2-norm.

    All other entries are set to zero. The result is written into a fresh
    array of length ``x.shape[0]``, so *x* itself is never modified. If the
    truncated vector has zero norm it is returned without rescaling.
    """
    magnitudes = np.abs(x)

    # Indices ordered by decreasing magnitude; the first k mark what to keep.
    by_decreasing_magnitude = np.argsort(magnitudes, axis=0)[::-1]
    keep = np.zeros_like(magnitudes, dtype=bool)
    keep[by_decreasing_magnitude[:k]] = True

    truncated = np.zeros(x.shape[0])
    truncated[keep] = x[keep]

    norm = np.linalg.norm(truncated, 2)
    return truncated / norm if norm else truncated

def find_max_eigenvec(S):
    """Return an eigenvector of *S* belonging to its largest eigenvalue.

    When *S* is square and (numerically) equal to its conjugate transpose the
    symmetric solver ``np.linalg.eigh`` is used: it always returns real
    eigenvalues in ascending order, so the last eigenvector is the answer.
    The general solver ``np.linalg.eig`` can return complex output for such
    input because of rounding. Any other input falls back to ``np.linalg.eig``
    with ``np.argmax`` over the eigenvalues.

    Returns:
        1-D array: the selected column of the eigenvector matrix. Its sign is
        whatever the underlying solver produces.
    """
    S = np.asarray(S)
    is_square = S.ndim == 2 and S.shape[0] == S.shape[1]
    if is_square and np.allclose(S, S.conj().T):
        _, eigenvectors = np.linalg.eigh(S)
        # eigh sorts eigenvalues in ascending order, so the last column wins.
        return eigenvectors[:, -1]

    eigenvalues, eigenvectors = np.linalg.eig(S)
    return eigenvectors[:, np.argmax(eigenvalues)]

def find_truncated_PC(S, k):
    """Truncate the leading eigenvector of *S* to k entries and score it.

    Returns:
        (z, value): ``z`` is the output of ``truncation`` applied to the
        flattened result of ``find_max_eigenvec(S)``, and ``value`` is the
        quadratic form ``z.T @ S @ z``.
    """
    leading_vec = find_max_eigenvec(S)
    sparse_vec = truncation(leading_vec.flatten(), k)
    objective = sparse_vec.T @ S @ sparse_vec
    return sparse_vec, objective

def find_truncated_column(S, k):
    """Pick the best candidate among truncated columns and standard basis vectors.

    Two families of candidates are scored with the quadratic form ``z.T @ S @ z``:

    * ``truncation(S[:, i], k)`` for every column ``i`` of *S*;
    * the standard basis vectors ``e_j``, whose score is ``S[j, j]``.

    Only strictly positive scores are considered in either family.

    Returns:
        (vector, value) of the winning candidate. A truncated column is
        returned exactly as produced by ``truncation``; a basis vector is
        returned as a ``(d, 1)`` array. If no candidate has a positive score,
        a ``(d, 1)`` zero vector and the value 0 are returned.
    """
    d = S.shape[0]

    max_value = 0
    max_column = np.zeros((d, 1))
    for i in range(d):
        z = truncation(S[:, i], k)
        value = z.T @ S @ z
        if value > max_value:
            max_value = value
            max_column = z

    max_value_basis = 0
    # Defined before the loop so the fallback return below cannot hit an
    # unbound name when no diagonal entry of S is positive.
    max_basis = np.zeros((d, 1))
    for j in range(d):
        # Finds the maximum for standard bases
        value = S[j, j]
        if value > max_value_basis:
            max_value_basis = value
            max_basis = np.zeros((d, 1))
            max_basis[j] = 1

    if max_value > max_value_basis:
        return max_column, max_value
    return max_basis, max_value_basis

def chan_app_alg(S, k):
    """Run both heuristics on *S* and return the better of the two candidates.

    ``find_truncated_column`` ("TS") and ``find_truncated_PC`` ("TE") are each
    timed with ``time.time()`` and their value and duration are printed.

    Returns:
        (vector, value, total_runtime): the candidate with the strictly larger
        value (TE wins ties) and the sum of the two measured durations.
    """
    # Candidate 1: truncated column search.
    ts_begin = time.time()
    vec_TS, value_TS = find_truncated_column(S, k)
    time_TS = time.time() - ts_begin
    print(f"TS found. The value is {value_TS}, time taken: {time_TS} seconds")

    # Candidate 2: truncated principal component.
    te_begin = time.time()
    vec_TE, value_TE = find_truncated_PC(S, k)
    time_TE = time.time() - te_begin
    print(f"TE found. The value is {value_TE}, time taken: {time_TE} seconds")

    total_runtime = time_TS + time_TE

    if value_TS > value_TE:
        return vec_TS, value_TS, total_runtime
    return vec_TE, value_TE, total_runtime

def find_block_diagonals(A, matrix):
    """Split the index set into connected components and cut the blocks out of *A*.

    Indices ``0 .. matrix.shape[1]-1`` are the nodes of a graph in which ``v``
    has an edge to ``w`` whenever ``matrix[v, w] != 0``. Only row ``v`` is
    scanned when looking for neighbours of ``v``, so a symmetric sparsity
    pattern is presumably expected. Components are discovered with an
    iterative depth-first search, starting from the smallest unvisited index.

    For each component the submatrix of *A* on those indices is extracted;
    components whose submatrix is entirely zero are dropped.

    Returns:
        (block_diagonals, indices, d_star): the kept submatrices of *A*, the
        index list of each one in visiting order, and the size of the largest
        kept component (0 when nothing is kept).
    """
    size = matrix.shape[1]
    seen = np.zeros(size, dtype=bool)

    block_diagonals = []
    indices = []
    d_star = 0

    for start in range(size):
        if seen[start]:
            continue

        # Iterative depth-first search collecting the component of `start`.
        component = []
        pending = [start]
        while pending:
            node = pending.pop()
            if seen[node]:
                continue
            seen[node] = True
            component.append(node)
            # Unvisited neighbours in increasing index order.
            unvisited = np.flatnonzero((matrix[node] != 0) & ~seen)
            pending.extend(unvisited.tolist())

        block = A[np.ix_(component, component)]
        if np.all(block == 0):
            continue
        block_diagonals.append(block)
        indices.append(component)
        d_star = max(d_star, len(component))

    return block_diagonals, indices, d_star


def solve_bd_spca(A, k, block_diagonals, indices, time_limit):
    """Solve the sparse-PCA heuristic on every block and keep the best block.

    Each non-zero block is handled in one of two ways:

    * fewer than *k* rows: the eigenvector of the largest (algebraic)
      eigenvalue is taken, computed with the dense symmetric solver
      ``np.linalg.eigh``;
    * otherwise: ``chan_app_alg(block, k)`` is called. If it raises, the error
      is printed and the block is skipped.

    A block replaces the incumbent only when its objective is strictly larger.

    Args:
        A: full matrix; only used to re-evaluate the winning vector.
        k: sparsity level passed to ``chan_app_alg``.
        block_diagonals: sequence of square blocks.
        indices: for each block, the list of its indices in *A*.
        time_limit: accepted for interface compatibility; this function does
            not read it.

    Returns:
        (x_best, ind_best, obj_best, original_obj, time_total):
        ``x_best`` is the winning vector (``(n, 1)`` for a small block,
        otherwise whatever ``chan_app_alg`` returned), ``ind_best`` the index
        list of its block, ``obj_best`` its objective on the block (a
        1-element array for a small block), ``original_obj`` the value
        ``x_best.T @ A[ind_best, ind_best] @ x_best`` and ``time_total`` the
        accumulated wall-clock time. When no block has a positive objective
        the initial values ``np.zeros((1, 1))``, ``[0]`` and ``0`` are kept.
    """
    time_total = 0
    obj_best = 0
    ind_best = [0]
    x_best = np.zeros((1, 1))

    for bd, bd_indices in zip(block_diagonals, indices):
        if np.all(bd == 0):
            continue

        start_time = time.time()
        if bd.shape[0] < k:
            # eigh returns eigenvalues in ascending order, so the last one is
            # the largest algebraically; ARPACK's 'LM' mode selected the
            # largest magnitude, which is wrong for indefinite blocks. The
            # slices keep the (1,) / (n, 1) shapes that eigsh(bd, 1) produced.
            eigenvalues, eigenvectors = np.linalg.eigh(bd)
            max_eigenvalue = eigenvalues[-1:]
            max_eigenvector = eigenvectors[:, -1:]
            if max_eigenvalue > obj_best:
                obj_best = max_eigenvalue
                ind_best = bd_indices
                x_best = max_eigenvector
        else:
            # Using chan_app_alg
            try:
                x_best_temp, obj, _ = chan_app_alg(bd, k)
                if obj > obj_best:
                    obj_best = obj
                    ind_best = bd_indices
                    x_best = x_best_temp
            except Exception as e:
                # Best effort: report the failure and move on to the next block.
                print("An error occurred:", str(e))
        time_total += time.time() - start_time

    start_time = time.time()
    original_obj = x_best.T @ A[np.ix_(ind_best, ind_best)] @ x_best
    time_total += time.time() - start_time

    return x_best, ind_best, obj_best, original_obj, time_total


def solve_bd_spca_bs(A, k, initial_threshold, a=0.1, b=10, max_d=100, tol=5e-2, time_limit=600):
    """Bisect over the threshold and solve block-diagonal SPCA at each candidate.

    The search interval is ``[L, U] = [a * initial_threshold, b * initial_threshold]``.
    A first instance is solved at threshold ``U``. Then, while ``U - L > tol``,
    the midpoint ``M`` is tried: *A* is thresholded at ``M``, split into blocks
    with ``find_block_diagonals`` and

    * if the largest block is not larger than the largest one already
      handled, the upper end moves down to ``M``;
    * if the largest block exceeds ``max_d``, the lower end moves up to ``M``;
    * otherwise ``solve_bd_spca`` is run on the blocks and progress is printed.

    The loop also stops once a block of size exactly ``max_d`` has been
    handled, or when the elapsed time reaches ``time_limit`` (checked after
    each solved instance).

    Args:
        A: matrix to analyse.
        k: sparsity level forwarded to ``solve_bd_spca``.
        initial_threshold: base value scaled by *a* and *b*.
        a, b: multipliers giving the lower and upper end of the interval.
        max_d: largest block size that will be solved.
        tol: interval width at which the bisection stops.
        time_limit: wall-clock budget in seconds.

    Returns:
        (best_x, best_ind, temp_obj, best_obj, total_time, M): the first four
        are the outputs of the most recent ``solve_bd_spca`` call,
        ``total_time`` is the elapsed wall-clock time and ``M`` is the last
        midpoint tried (``U`` if the loop body never ran).
    """
    # In this function, we call solve_bd_spca many times
    start_time = time.time()
    U = b * initial_threshold
    L = a * initial_threshold
    # The first instance is solved at threshold U; using it as the initial
    # value keeps the final `return ..., M` valid when the loop never runs.
    M = U

    S = threshold_matrix(A, U)
    block_diagonals, indices, d_star = find_block_diagonals(A, S)

    print("Now solving the first spca instance.")
    best_x, best_ind, temp_obj, best_obj, time_passed = solve_bd_spca(
        A, k, block_diagonals, indices, max(time_limit - time.time() + start_time, 0))

    i = 2
    # record the max d_star visited within computational constraints
    max_d_star = 0
    while U - L > tol:
        if max_d_star == max_d:
            # meaning that we have already reached the computational limit in previous attempt
            # We should stop immediately
            break

        if d_star <= max_d:
            max_d_star = max(d_star, max_d_star)
        best_obj_old = best_obj

        M = (U + L) / 2
        start_sorting_time = time.time()
        S = threshold_matrix(A, M)
        block_diagonals, indices, d_star = find_block_diagonals(A, S)

        if d_star <= max_d_star:
            # results are the same or worse
            U = M
            print(f"Current threshold is {M}, and gives the same d_star.\n")
            continue

        if d_star > max_d:
            # We cannot afford such computation
            L = M
            print(f"Current threshold is {M}, d_star is {d_star}, and exceeds computational resource.\n")
            continue

        print(f"Now running block diagonal spca for threshold {M}, with d_star being {d_star}.")
        print(f"This is the {i}-th spca instance.")
        i = i + 1

        sorting_time = time.time() - start_sorting_time

        # Else, we know that we can afford the computation, and the result is going to be potentially better
        best_x, best_ind, temp_obj, best_obj, time_passed = solve_bd_spca(A, k, block_diagonals, indices, time_limit)
        print(f"Best obj found is {best_obj}. The runtime for this instance is {time_passed + sorting_time}.")
        print(f"Best index set found is {best_ind}.")

        supp = find_supp(best_x, best_ind)
        print(f'Current support is {supp}.')
        if supp:
            # Re-optimise on the support: leading eigenvector of the submatrix.
            A_supp = A[supp][:, supp]
            _, V1 = np.linalg.eigh(A_supp)
            y1 = V1[:, -1]
            better_PC_value = y1.T @ A_supp @ y1
            print(f"Better obj found is {better_PC_value}.")

        total_time = time.time() - start_time
        print(f"The total runtime is {total_time}.")

        if total_time >= time_limit:
            print("Time limit reached.\n")
            break

        if abs(best_obj - best_obj_old) < 1e-2:
            print("Unchanged objective value detected.")
            U = M
        print("\n")

    total_time = time.time() - start_time
    print(f"Best opt found is {best_obj}.")
    print(f"Total runtime is {total_time}.")
    print("\n")

    return best_x, best_ind, temp_obj, best_obj, total_time, M
    


# ---------------------------------------------------------------------------
# Driver: read the options, load the matrix, run chan_app_alg on the full
# problem and then the block-diagonal threshold search.
# ---------------------------------------------------------------------------
args = parse_args()
file_path = args.filepath
k = args.k
time_limit = args.time_limit
d_max = args.d_max
initial_threshold = 1

A = np.genfromtxt(file_path, delimiter=',')
n, d = A.shape
if n != d:
    # Not a square matrix: replace it by the covariance computed by
    # np.cov on the transpose.
    A = np.cov(A.T)

# Largest absolute entry of A; used below to scale the threshold interval.
A_inf_norm = np.max(np.abs(A))

print("Data loaded from txt and processed.")

print("Now starting to solve the original SPCA using chan_app_alg.")

# Baseline: chan_app_alg applied directly to the whole matrix.
try:
    vec, obj_value, total_runtime = chan_app_alg(A, k)
    spca_obj = obj_value
    spca_xVal = vec
    spca_timetoBound = total_runtime

    print("Objective value:", spca_obj)
    print("Time to bound (total runtime):", spca_timetoBound)
except Exception as e:
    print("An error occurred:", str(e))

print('\n\n\n')

print("Now running block diagonal method:")

# With initial_threshold = 1 the search interval is [0, A_inf_norm].
best_x, best_ind, temp_obj, best_obj, total_time, current_threshold = solve_bd_spca_bs(
    A,
    k,
    initial_threshold,
    a=0,
    b=A_inf_norm,
    max_d=d_max,
    tol=1e-2 * A_inf_norm,
    time_limit=time_limit,
)

print(f"Best obj found is {best_obj}.")
print(f"Instance size is {len(best_ind)}.")
print(f"Total runtime is {total_time}.")
print(f"Best threshold found is {current_threshold}.")

    
    






